{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a4292c3b",
   "metadata": {},
   "source": [
    "# AutoMM for Semantic Segmentation - Quick Start\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/autogluon/autogluon/blob/master/docs/tutorials/multimodal/image_segmentation/beginner_semantic_seg.ipynb)\n",
    "[![Open In SageMaker Studio Lab](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/autogluon/autogluon/blob/master/docs/tutorials/multimodal/image_segmentation/beginner_semantic_seg.ipynb)\n",
    "\n",
    "\n",
    "Semantic segmentation is a computer vision task that assigns every pixel of an image to a class, producing a pixel-wise segmentation map. It is crucial in applications such as autonomous driving, where a vehicle must identify other vehicles, pedestrians, traffic signs, pavement, and other road features.\n",
    "\n",
    "The Segment Anything Model (SAM) is a foundation model pretrained on a vast dataset with 1 billion masks and 11 million images. While SAM performs exceptionally well on generic scenes, it struggles in specialized domains such as remote sensing, medical imagery, agriculture, and manufacturing. AutoMM addresses this by making it easy to fine-tune SAM on domain-specific data.\n",
    "\n",
    "This tutorial walks through using AutoMM to fine-tune SAM. Training takes a single call to the `fit()` API.\n",
    "\n",
    "\n",
    "## Prepare Data\n",
    "For demonstration purposes, we use the [Leaf Disease Segmentation](https://www.kaggle.com/datasets/sovitrath/leaf-disease-segmentation-with-trainvalid-split) dataset from Kaggle. It is a good example for automating disease detection in plants and speeding up plant pathology work. Segmenting diseased regions on leaves can be challenging, particularly when the diseased areas are small or several types of disease are present.\n",
    "\n",
    "To begin, download and prepare the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa00faab-252f-44c9-b8f7-57131aa8251c",
   "metadata": {
    "tags": [
     "remove-cell"
    ]
   },
   "outputs": [],
   "source": [
    "!pip install autogluon.multimodal\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d28424b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from autogluon.core.utils.loaders import load_zip\n",
    "\n",
    "download_dir = './ag_automm_tutorial'\n",
    "zip_file = 'https://automl-mm-bench.s3.amazonaws.com/semantic_segmentation/leaf_disease_segmentation.zip'\n",
    "load_zip.unzip(zip_file, unzip_dir=download_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "66a6f5d5",
   "metadata": {},
   "source": [
    "Next, load the CSV files, ensuring that relative paths are expanded to facilitate correct data loading during both training and testing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5609f0c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import os\n",
    "dataset_path = os.path.join(download_dir, 'leaf_disease_segmentation')\n",
    "train_data = pd.read_csv(f'{dataset_path}/train.csv', index_col=0)\n",
    "val_data = pd.read_csv(f'{dataset_path}/val.csv', index_col=0)\n",
    "test_data = pd.read_csv(f'{dataset_path}/test.csv', index_col=0)\n",
    "image_col = 'image'\n",
    "label_col = 'label'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f0289f4",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def path_expander(path, base_folder):\n",
    "    # A cell may hold several ';'-separated relative paths; make each one absolute.\n",
    "    path_l = path.split(';')\n",
    "    return ';'.join([os.path.abspath(os.path.join(base_folder, p)) for p in path_l])\n",
    "\n",
    "# Expand the image and mask paths in every split.\n",
    "for per_col in [image_col, label_col]:\n",
    "    train_data[per_col] = train_data[per_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))\n",
    "    val_data[per_col] = val_data[per_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))\n",
    "    test_data[per_col] = test_data[per_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))\n",
    "\n",
    "print(train_data[image_col].iloc[0])\n",
    "print(train_data[label_col].iloc[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38ba9123",
   "metadata": {},
   "source": [
    "Each Pandas DataFrame contains two columns: one for image paths and the other for the corresponding ground-truth masks. Let's take a closer look at the training DataFrame."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "191ba0df",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "train_data.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77034c07",
   "metadata": {},
   "source": [
    "We can also visualize one image and its ground-truth mask."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "659ed670",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from autogluon.multimodal.utils import SemanticSegmentationVisualizer\n",
    "visualizer = SemanticSegmentationVisualizer()\n",
    "visualizer.plot_image(test_data.iloc[0]['image'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c706598",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "visualizer.plot_image(test_data.iloc[0]['label'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f93dcc50",
   "metadata": {},
   "source": [
    "## Zero-Shot Evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7740d88a",
   "metadata": {},
   "source": [
    "Now, let's see how well the pretrained SAM can segment the images. For this demonstration, we'll use the base SAM model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc053112",
   "metadata": {},
   "outputs": [],
   "source": [
    "from autogluon.multimodal import MultiModalPredictor\n",
    "predictor_zero_shot = MultiModalPredictor(\n",
    "    problem_type=\"semantic_segmentation\",\n",
    "    label=label_col,\n",
    "    hyperparameters={\n",
    "        \"model.sam.checkpoint_name\": \"facebook/sam-vit-base\",\n",
    "    },\n",
    "    num_classes=1,  # foreground-background segmentation\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68ff09f5",
   "metadata": {},
   "source": [
    "After initializing the predictor, you can perform inference directly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cfddea7",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_zero_shot = predictor_zero_shot.predict({'image': [test_data.iloc[0]['image']]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8105f39c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "visualizer.plot_mask(pred_zero_shot)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4d2179d",
   "metadata": {},
   "source": [
    "Without prompts, SAM outputs a rough mask of the whole leaf rather than the diseased regions, because it has no context about the domain task. SAM can perform better with proper click prompts, but a prompt-driven workflow may not be an ideal end-to-end solution for applications that require a standalone model for deployment."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4dcc624",
   "metadata": {},
   "source": [
    "You can also conduct a zero-shot evaluation on the test data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c421486",
   "metadata": {},
   "outputs": [],
   "source": [
    "scores = predictor_zero_shot.evaluate(test_data, metrics=[\"iou\"])\n",
    "print(scores)"
   ]
  },
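  {
   "cell_type": "markdown",
   "id": "iou-metric-sketch",
   "metadata": {},
   "source": [
    "The `iou` metric requested above is intersection over union: the number of pixels where the predicted and ground-truth foreground overlap, divided by the number of pixels covered by either. The following is a minimal NumPy sketch of that standard definition for a single pair of binary masks. It is not AutoMM's internal implementation; `binary_iou` is a hypothetical helper, and returning 1.0 for two empty masks is an assumed convention.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def binary_iou(pred_mask, true_mask):\n",
    "    # Hypothetical helper: treat any non-zero pixel as foreground.\n",
    "    pred = np.asarray(pred_mask) > 0\n",
    "    true = np.asarray(true_mask) > 0\n",
    "    intersection = np.logical_and(pred, true).sum()\n",
    "    union = np.logical_or(pred, true).sum()\n",
    "    # Assumed convention: two empty masks count as a perfect match.\n",
    "    if union == 0:\n",
    "        return 1.0\n",
    "    return float(intersection / union)\n",
    "\n",
    "\n",
    "pred = np.array([[1, 1], [0, 0]])\n",
    "true = np.array([[1, 0], [0, 0]])\n",
    "# 1 overlapping pixel out of 2 pixels in the union gives 0.5.\n",
    "print(binary_iou(pred, true))\n",
    "```\n",
    "\n",
    "A score of 1.0 means the two masks agree on every foreground pixel, and a score near 0 means they barely overlap."
   ]
  },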
  {
   "cell_type": "markdown",
   "id": "778d4b0b",
   "metadata": {},
   "source": [
    "As expected, the test score of the zero-shot SAM is relatively low. Next, let's explore how to fine-tune SAM for enhanced performance."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6640353",
   "metadata": {},
   "source": [
    "## Fine-tune SAM\n",
    "\n",
    "Initialize a new predictor and fit it with the training and validation data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96e31052",
   "metadata": {},
   "outputs": [],
   "source": [
    "from autogluon.multimodal import MultiModalPredictor\n",
    "import uuid\n",
    "save_path = f\"./tmp/{uuid.uuid4().hex}-automm_semantic_seg\"\n",
    "predictor = MultiModalPredictor(\n",
    "    problem_type=\"semantic_segmentation\",\n",
    "    label=label_col,\n",
    "    hyperparameters={\n",
    "        \"model.sam.checkpoint_name\": \"facebook/sam-vit-base\",\n",
    "    },\n",
    "    path=save_path,\n",
    ")\n",
    "predictor.fit(\n",
    "    train_data=train_data,\n",
    "    tuning_data=val_data,\n",
    "    time_limit=180,  # seconds\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "199ced03",
   "metadata": {},
   "source": [
    "Under the hood, we use [LoRA](https://arxiv.org/abs/2106.09685) for efficient fine-tuning. Note that if you do not customize the hyperparameters, AutoMM defaults to the huge SAM checkpoint, for which parameter-efficient fine-tuning is needed in many cases."
   ]
  },
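  {
   "cell_type": "markdown",
   "id": "lora-update-sketch",
   "metadata": {},
   "source": [
    "To make the idea concrete, here is a minimal NumPy sketch of the LoRA update from the paper linked above: the pretrained weight `W` stays frozen, and only two small matrices `A` and `B` are trained, with their product added as a low-rank correction. This illustrates the general technique under assumed toy sizes; it is not how AutoMM wires LoRA into SAM, and `lora_forward`, `rank`, and `alpha` are illustrative names.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "d_out, d_in, rank, alpha = 8, 8, 2, 4  # hypothetical toy sizes\n",
    "\n",
    "W = rng.normal(size=(d_out, d_in))        # frozen pretrained weight\n",
    "A = rng.normal(size=(rank, d_in)) * 0.01  # trainable, small random init\n",
    "B = np.zeros((d_out, rank))               # trainable, zero init so the update starts at zero\n",
    "\n",
    "\n",
    "def lora_forward(x):\n",
    "    # Frozen path plus the scaled low-rank correction (alpha / rank) * B @ A.\n",
    "    return x @ W.T + (alpha / rank) * (x @ A.T @ B.T)\n",
    "\n",
    "\n",
    "x = rng.normal(size=(3, d_in))\n",
    "# Before any training step, B is zero, so the output equals the frozen layer's output.\n",
    "assert np.allclose(lora_forward(x), x @ W.T)\n",
    "# Compare the frozen parameter count with the trainable one.\n",
    "print(W.size, A.size + B.size)\n",
    "```\n",
    "\n",
    "Here the frozen weight has 64 entries while `A` and `B` together have 32; with realistic layer sizes and a small rank, the trainable fraction is far smaller."
   ]
  },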
  {
   "cell_type": "markdown",
   "source": [
    "After fine-tuning, evaluate SAM on the test data."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e4078bd",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "scores = predictor.evaluate(test_data, metrics=[\"iou\"])\n",
    "print(scores)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b42fc8e8",
   "metadata": {},
   "source": [
    "Thanks to the fine-tuning process, the test score has significantly improved."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c60aa65",
   "metadata": {},
   "source": [
    "To visualize the impact, let's examine the predicted mask after fine-tuning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f21c9030",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "pred = predictor.predict({'image': [test_data.iloc[0]['image']]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1597bcc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "visualizer.plot_mask(pred)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3defcf3",
   "metadata": {},
   "source": [
    "As the results show, the predicted mask is now much closer to the ground truth. This demonstrates the effectiveness of using AutoMM to fine-tune SAM for domain-specific applications such as leaf disease segmentation."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67728302",
   "metadata": {},
   "source": [
    "## Save and Load\n",
    "\n",
    "The trained predictor is automatically saved at the end of `fit()`, and you can easily reload it.\n",
    "\n",
    "```{warning}\n",
    "\n",
    "`MultiModalPredictor.load()` uses the `pickle` module implicitly, which is known to be insecure. It is possible to construct malicious pickle data that executes arbitrary code during unpickling. Never load data that could have come from an untrusted source, or that could have been tampered with. **Only load data you trust.**\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6df38a6e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "loaded_predictor = MultiModalPredictor.load(save_path)\n",
    "scores = loaded_predictor.evaluate(test_data, metrics=[\"iou\"])\n",
    "print(scores)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f650eda5",
   "metadata": {},
   "source": [
    "The evaluation score matches the one above, which confirms that the loaded predictor is the same model.\n",
    "\n",
    "## Other Examples\n",
    "\n",
    "See [AutoMM Examples](https://github.com/autogluon/autogluon/tree/master/examples/automm) for more examples of using AutoMM.\n",
    "\n",
    "## Customization\n",
    "To learn how to customize AutoMM, please refer to [Customize AutoMM](../advanced_topics/customization.ipynb)."
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
